import torch
import torch.nn as nn
from PIL import Image
import clip


class RewardModel(nn.Module):
    """MLP reward head that scores an (image embedding, text embedding) pair.

    The two embeddings are concatenated along the last axis and pushed through
    Linear/ReLU(/Dropout) stages that shrink 2*embed_dim -> 1024 -> 1024 ->
    128 -> 16 -> 1. A final Sigmoid squashes the score into [0, 1].

    :param embed_dim: size of each input embedding (768 for CLIP ViT-L/14);
        the first Linear layer therefore expects ``2 * embed_dim`` features.
    """

    def __init__(self, embed_dim=768):
        super().__init__()
        # The layer order (and hence the ``fc.<index>`` state-dict keys) must
        # stay exactly like this so previously saved checkpoints still load.
        layers = [
            nn.Linear(2 * embed_dim, 1024), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(1024, 1024), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(1024, 128), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(128, 16), nn.ReLU(),
            nn.Linear(16, 1), nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, image_embed, text_embed):
        """Score the pair; the image embedding goes first in the concatenation.

        :param image_embed: tensor whose last dimension is ``embed_dim``.
        :param text_embed: tensor with the same leading dimensions as
            ``image_embed`` and last dimension ``embed_dim``.
        :return: tensor of shape ``(..., 1)`` with values in [0, 1].
        """
        joint = torch.cat([image_embed, text_embed], dim=-1)
        return self.fc(joint)


def calculate_batch_scores(prompts, imgs, preprocess, clip_model, reward_model, device):
    """
    Score a batch of image-text pairs with CLIP embeddings and a reward head.

    :param prompts: text prompts, one per image, length batch_size
        (e.g. ["text1", "text2", "text3", "text4"]). A single string is
        treated as a batch of one.
    :param imgs: images, one per prompt, each accepted by
        ``PIL.Image.fromarray`` -- e.g. an array of shape
        (batch_size, H, W, C) such as (4, 512, 512, 3).
    :param preprocess: image preprocessing callable (PIL image -> tensor).
    :param clip_model: CLIP model providing ``encode_text`` / ``encode_image``.
    :param reward_model: reward head, called as
        ``reward_model(image_embeddings, text_embeddings)``.
    :param device: device to run on (CPU or GPU).
    :return: tensor of scores with shape (batch_size, 1) on ``device``;
        an empty batch yields an empty (0, 1) tensor.
    :raises ValueError: if the number of prompts and images differ.

    Both models run in eval mode only for the duration of this call. Each
    submodule's previous train/eval flag is restored before returning, so
    calling this from inside a training loop does not silently disable
    dropout for the caller.
    """
    # Materialise inputs so len() works for any iterable. A bare string would
    # otherwise be counted character by character.
    prompts = [prompts] if isinstance(prompts, str) else list(prompts)
    imgs = list(imgs)
    if len(prompts) != len(imgs):
        raise ValueError(
            f"got {len(prompts)} prompts but {len(imgs)} images; "
            "they must pair up one-to-one"
        )
    if not prompts:
        # torch.stack([]) raises, so return an empty score column instead.
        return torch.empty((0, 1), device=device)

    # Preprocess on the host, then move the whole batch in a single transfer.
    imgs_tensor = torch.stack(
        [preprocess(Image.fromarray(img)) for img in imgs]
    ).to(device)  # (batch_size, 3, h, w); h and w are set by `preprocess`
    text_tokens = clip.tokenize(prompts, truncate=True).to(device)  # (batch_size, context_length)

    # Save every submodule's flag, not just the top-level one: Module.train(mode)
    # would overwrite mixed states (e.g. a deliberately frozen sub-block).
    saved_modes = [
        (module, module.training)
        for model in (reward_model, clip_model)
        for module in model.modules()
    ]
    reward_model.eval()
    clip_model.eval()
    try:
        with torch.no_grad():
            # .float() so the reward head receives fp32 regardless of CLIP's dtype.
            text_embeddings = clip_model.encode_text(text_tokens).float()
            image_embeddings = clip_model.encode_image(imgs_tensor).float()
            # NOTE(review): embeddings are passed un-normalised; confirm this
            # matches how reward_model was trained.
            scores = reward_model(image_embeddings, text_embeddings)  # (batch_size, 1)
    finally:
        for module, was_training in saved_modes:
            module.training = was_training

    return scores